{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vKadZFQ2IdJb"
      },
      "source": [
        "# Post training an LLM for reasoning with GRPO in TRL\n",
        "\n",
        "_Authored by: [Sergio Paniego](https://github.com/sergiopaniego)_\n",
        "\n",
        "In this notebook, we'll guide you through the process of post-training a Large Language Model (LLM) using **Group Relative Policy Optimization (GRPO)**, a method introduced in the [DeepSeekMath paper](https://arxiv.org/abs/2402.03300). GRPO is particularly effective for **scaling test-time compute for extended reasoning**, making it an ideal approach for solving complex tasks, such as mathematical problem-solving.\n",
        "\n",
        "GRPO is a **reinforcement learning (RL) post-training technique** that was integrated into the training pipeline for [**DeepSeek-R1**](https://github.com/deepseek-ai/DeepSeek-R1). It seems to share similarities with the training procedures used in the latest [**OpenAI o1 and o3 models**](https://openai.com/index/learning-to-reason-with-llms/), though the exact alignment is not confirmed. Unlike earlier techniques that relied on search-heuristic methods, GRPO exclusively employs **RL** for post-training, enhancing the model's capacity to handle complex and nuanced tasks.\n",
        "\n",
        "\n",
        "\n",
        "The GRPO technique is available through the [TRL library](https://huggingface.co/docs/trl/main/en/grpo_trainer#quick-start). At the time of writing, the Hugging Face Science team is working to reproduce the full **DeepSeek-R1** training process, which you can explore in their [Open-R1 project](https://github.com/huggingface/open-r1). I highly recommend checking it out for a deeper dive into the overall process.\n",
        "\n",
        "In this notebook, we'll focus specifically on **post-training with GRPO**, though additional resources on DeepSeek-R1 and its training procedure are provided in the last section.\n",
        "\n",
        "Below is a diagram illustrating how this training procedure works.\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "![Image](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)"
      ],
      "metadata": {
        "id": "RGLPwWtxsKQ-"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gSHmDKNFoqjC"
      },
      "source": [
        "## 1. Install Dependencies\n",
        "\n",
        "Let’s start by installing the essential libraries we’ll need for fine-tuning! 🚀\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GCMhPmFdIGSb"
      },
      "outputs": [],
      "source": [
        "!pip install  -U -q trl peft math_verify\n",
        "# Tested with transformers==4.47.1, trl==0.14.0, datasets==3.2.0, peft==0.14.0, accelerate==1.2.1, math_verify==0.3.3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V0-2Lso6wkIh"
      },
      "source": [
        "Authenticate with your Hugging Face account to save and share your model directly from this notebook 🗝️."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xcL4-bwGIoaR"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g9QXwbJ7ovM5"
      },
      "source": [
        "## 2. Load Dataset 📁\n",
        "\n",
        "These models excel at tasks that require **complex reasoning**. A prime example is **mathematical problem-solving**, which often demands multi-step reasoning to arrive at a correct solution.\n",
        "\n",
        "For this project, we'll use the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. This is a **reasoning-focused dataset** that contains mathematical problems, their solutions, and detailed reasoning steps that explain how to transition from the problem statement to the final solution.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QFe_A78aIwK8"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset_id = 'AI-MO/NuminaMath-TIR'\n",
        "train_dataset, test_dataset = load_dataset(dataset_id, split=['train[:5%]', 'test[:5%]'])"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's check the structure of the dataset"
      ],
      "metadata": {
        "id": "Z52AYDWVWXFQ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "o2UKZj15jGwv",
        "outputId": "0a62027c-75a5-440e-8a15-1d8e49390478"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset({\n",
            "    features: ['problem', 'solution', 'messages'],\n",
            "    num_rows: 3622\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "print(train_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's check one sample:"
      ],
      "metadata": {
        "id": "Enl47nVlWaiY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(train_dataset[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5BglCVLLzY0S",
        "outputId": "2f8baf75-d0ee-4682-9a9d-9ea9901db326"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'problem': 'What is the coefficient of $x^2y^6$ in the expansion of $\\\\left(\\\\frac{3}{5}x-\\\\frac{y}{2}\\\\right)^8$?  Express your answer as a common fraction.', 'solution': \"To determine the coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\), we can use the binomial theorem.\\n\\nThe binomial theorem states:\\n\\\\[\\n(a + b)^n = \\\\sum_{k=0}^{n} \\\\binom{n}{k} a^{n-k} b^k\\n\\\\]\\n\\nIn this case, \\\\(a = \\\\frac{3}{5}x\\\\), \\\\(b = -\\\\frac{y}{2}\\\\), and \\\\(n = 8\\\\).\\n\\nWe are interested in the term that contains \\\\(x^2y^6\\\\). In the general term of the binomial expansion:\\n\\\\[\\n\\\\binom{8}{k} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-k} \\\\left(-\\\\frac{y}{2}\\\\right)^k\\n\\\\]\\n\\nTo get \\\\(x^2\\\\), we need \\\\(8 - k = 2\\\\), thus \\\\(k = 6\\\\).\\n\\nSubstituting \\\\(k = 6\\\\) into the expression:\\n\\\\[\\n\\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-6} \\\\left(-\\\\frac{y}{2}\\\\right)^6 = \\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^2 \\\\left(-\\\\frac{y}{2}\\\\right)^6\\n\\\\]\\n\\nNow, we will compute each part of this expression.\\n\\n1. Calculate the binomial coefficient \\\\(\\\\binom{8}{6}\\\\).\\n2. Compute \\\\(\\\\left(\\\\frac{3}{5}\\\\right)^2\\\\).\\n3. Compute \\\\(\\\\left(-\\\\frac{y}{2}\\\\right)^6\\\\).\\n4. 
Combine everything together to get the coefficient of \\\\(x^2y^6\\\\).\\n\\nLet's compute these in Python.\\n```python\\nfrom math import comb\\n\\n# Given values\\nn = 8\\nk = 6\\n\\n# Calculate the binomial coefficient\\nbinom_coeff = comb(n, k)\\n\\n# Compute (3/5)^2\\na_term = (3/5)**2\\n\\n# Compute (-1/2)^6\\nb_term = (-1/2)**6\\n\\n# Combine terms to get the coefficient of x^2y^6\\ncoefficient = binom_coeff * a_term * b_term\\nprint(coefficient)\\n```\\n```output\\n0.1575\\n```\\nThe coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\) is \\\\(0.1575\\\\). To express this as a common fraction, we recognize that:\\n\\n\\\\[ 0.1575 = \\\\frac{1575}{10000} = \\\\frac{63}{400} \\\\]\\n\\nThus, the coefficient can be expressed as:\\n\\n\\\\[\\n\\\\boxed{\\\\frac{63}{400}}\\n\\\\]\", 'messages': [{'content': 'What is the coefficient of $x^2y^6$ in the expansion of $\\\\left(\\\\frac{3}{5}x-\\\\frac{y}{2}\\\\right)^8$?  Express your answer as a common fraction.', 'role': 'user'}, {'content': \"To determine the coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\), we can use the binomial theorem.\\n\\nThe binomial theorem states:\\n\\\\[\\n(a + b)^n = \\\\sum_{k=0}^{n} \\\\binom{n}{k} a^{n-k} b^k\\n\\\\]\\n\\nIn this case, \\\\(a = \\\\frac{3}{5}x\\\\), \\\\(b = -\\\\frac{y}{2}\\\\), and \\\\(n = 8\\\\).\\n\\nWe are interested in the term that contains \\\\(x^2y^6\\\\). 
In the general term of the binomial expansion:\\n\\\\[\\n\\\\binom{8}{k} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-k} \\\\left(-\\\\frac{y}{2}\\\\right)^k\\n\\\\]\\n\\nTo get \\\\(x^2\\\\), we need \\\\(8 - k = 2\\\\), thus \\\\(k = 6\\\\).\\n\\nSubstituting \\\\(k = 6\\\\) into the expression:\\n\\\\[\\n\\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^{8-6} \\\\left(-\\\\frac{y}{2}\\\\right)^6 = \\\\binom{8}{6} \\\\left(\\\\frac{3}{5}x\\\\right)^2 \\\\left(-\\\\frac{y}{2}\\\\right)^6\\n\\\\]\\n\\nNow, we will compute each part of this expression.\\n\\n1. Calculate the binomial coefficient \\\\(\\\\binom{8}{6}\\\\).\\n2. Compute \\\\(\\\\left(\\\\frac{3}{5}\\\\right)^2\\\\).\\n3. Compute \\\\(\\\\left(-\\\\frac{y}{2}\\\\right)^6\\\\).\\n4. Combine everything together to get the coefficient of \\\\(x^2y^6\\\\).\\n\\nLet's compute these in Python.\\n```python\\nfrom math import comb\\n\\n# Given values\\nn = 8\\nk = 6\\n\\n# Calculate the binomial coefficient\\nbinom_coeff = comb(n, k)\\n\\n# Compute (3/5)^2\\na_term = (3/5)**2\\n\\n# Compute (-1/2)^6\\nb_term = (-1/2)**6\\n\\n# Combine terms to get the coefficient of x^2y^6\\ncoefficient = binom_coeff * a_term * b_term\\nprint(coefficient)\\n```\\n```output\\n0.1575\\n```\\nThe coefficient of \\\\(x^2y^6\\\\) in the expansion of \\\\(\\\\left(\\\\frac{3}{5}x - \\\\frac{y}{2}\\\\right)^8\\\\) is \\\\(0.1575\\\\). To express this as a common fraction, we recognize that:\\n\\n\\\\[ 0.1575 = \\\\frac{1575}{10000} = \\\\frac{63}{400} \\\\]\\n\\nThus, the coefficient can be expressed as:\\n\\n\\\\[\\n\\\\boxed{\\\\frac{63}{400}}\\n\\\\]\", 'role': 'assistant'}]}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "In the **DeepSeek-R1** training procedure, a specific system prompt was used to generate a conversational pipeline that includes reasoning steps. We'll adapt our dataset to follow this approach, where the model is guided to first think through the problem and then present its answer.\n",
        "\n",
        "The system prompt used is:\n",
        "\n",
        "```\n",
        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\n",
        "The assistant first thinks about the reasoning process in the mind and then provides the user\n",
        "with the answer. The reasoning process and answer are enclosed within <think> </think> and\n",
        "<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>\n",
        "<answer> answer here </answer>. User: prompt. Assistant:\n",
        "```\n",
        "\n",
        "We will modify our dataset to follow this conversational format, prompting the LLM to generate both the reasoning steps and the final answer.\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "6isapXWue91d"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "SYSTEM_PROMPT = (\n",
        "    \"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant \"\n",
        "    \"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning \"\n",
        "    \"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., \"\n",
        "    \"<think> reasoning process here </think><answer> answer here </answer>\"\n",
        ")\n",
        "\n",
        "def make_conversation(example):\n",
        "    return {\n",
        "        \"prompt\": [\n",
        "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
        "            {\"role\": \"user\", \"content\": example[\"problem\"]},\n",
        "        ],\n",
        "    }\n",
        "\n",
        "train_dataset = train_dataset.map(make_conversation)\n",
        "test_dataset = test_dataset.map(make_conversation)"
      ],
      "metadata": {
        "id": "iXsh50jY_hQM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's take a look at an example:"
      ],
      "metadata": {
        "id": "T7gpvXeF0AUh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(train_dataset[0]['prompt'])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "s48vCAy3e1x0",
        "outputId": "64cf0145-078f-48bb-fde6-6c6d031d801d"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[{'content': 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>', 'role': 'system'}, {'content': 'What is the coefficient of $x^2y^6$ in the expansion of $\\\\left(\\\\frac{3}{5}x-\\\\frac{y}{2}\\\\right)^8$?  Express your answer as a common fraction.', 'role': 'user'}]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We'll remove the `messages` and `problem` columns, as we only need the custom `prompt` column and `solution` to verify the generated answer.  "
      ],
      "metadata": {
        "id": "q6ijkZ3VmxA4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "train_dataset = train_dataset.remove_columns(['messages', 'problem'])\n",
        "print(train_dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EaY8lUYSHyhA",
        "outputId": "6d64ab57-1e69-4e85-b0fb-c8c3db885c89"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset({\n",
            "    features: ['solution', 'prompt'],\n",
            "    num_rows: 3622\n",
            "})\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YIZOIVEzQqNg"
      },
      "source": [
        "## 3. Post-Training the Base Model Using GRPO\n",
        "\n",
        "The diagram below highlights the main differences between **PPO** (Proximal Policy Optimization) and **GRPO** (Group Relative Policy Optimization), specifically the removal of the value model in GRPO. For more detailed information on the key differences, you can refer to the [full explanation here](https://www.philschmid.de/deepseek-r1)."
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "![ppo_grpo.jpeg]()"
      ],
      "metadata": {
        "id": "BbIwDpT8F_aq"
      }
    },
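    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the removal of the value model concrete: GRPO scores each completion relative to the other completions sampled for the same prompt, instead of against a learned value estimate. The sketch below illustrates that idea under stated assumptions. The helper `group_relative_advantages`, the `eps` term, the use of the population standard deviation, and the reward values are all illustrative choices, not TRL's exact implementation.\n",
        "\n",
        "```python\n",
        "from statistics import mean, pstdev\n",
        "\n",
        "def group_relative_advantages(rewards, eps=1e-4):\n",
        "    # Normalize each reward against the mean and spread of its own group\n",
        "    mu = mean(rewards)\n",
        "    sigma = pstdev(rewards)\n",
        "    return [(r - mu) / (sigma + eps) for r in rewards]\n",
        "\n",
        "# Hypothetical summed rewards for four completions of one prompt\n",
        "group_rewards = [2.0, 1.0, 0.0, 1.0]\n",
        "advantages = group_relative_advantages(group_rewards)\n",
        "print([round(a, 3) for a in advantages])\n",
        "\n",
        "# A group where every completion scores the same\n",
        "flat_advantages = group_relative_advantages([1.0, 1.0, 1.0, 1.0])\n",
        "print(flat_advantages)\n",
        "```\n",
        "\n",
        "In this sketch, completions that beat their group's average get a positive advantage and the rest a negative one. When every completion in a group receives the same reward, all advantages are zero, so that prompt provides no signal to learn from."
      ]
    },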
    {
      "cell_type": "markdown",
      "source": [
        "### 3.1 Loading the Baseline Model\n",
        "\n",
        "To begin, we'll load [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the baseline model (`Policy Model` in the diagram above). With only 0.5 billion parameters, it is lightweight and fits within the available resources. However, for better results, a larger [alternative](https://x.com/jiayi_pirate/status/1882839487417561307) should be considered.  \n"
      ],
      "metadata": {
        "id": "D-UlkRzREf-J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoModelForCausalLM\n",
        "\n",
        "model_id = \"Qwen/Qwen2-0.5B-Instruct\"\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    torch_dtype=\"auto\",\n",
        "    device_map=\"auto\",\n",
        ")"
      ],
      "metadata": {
        "id": "qv02eazzEUeJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### 3.2 Configuring LoRA\n",
        "\n",
        "Next, we will configure LoRA for model training. This technique will allow us to efficiently fine-tune the model with a reduced number of parameters, enabling faster and more resource-efficient training."
      ],
      "metadata": {
        "id": "OksUs_tWEXvR"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ITmkRHWCKYjf",
        "outputId": "4229a71a-2381-402f-a1c4-8f2492150cc5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "trainable params: 540,672 || all params: 494,573,440 || trainable%: 0.1093\n"
          ]
        }
      ],
      "source": [
        "from peft import LoraConfig, get_peft_model\n",
        "\n",
        "lora_config = LoraConfig(\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    r=8,\n",
        "    lora_alpha=32,\n",
        "    lora_dropout=0.1,\n",
        "    target_modules=[\"q_proj\", \"v_proj\"],\n",
        ")\n",
        "\n",
        "model = get_peft_model(model, lora_config)\n",
        "\n",
        "model.print_trainable_parameters()"
      ]
    },
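    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check on the trainable-parameter count printed above, the arithmetic can be reproduced by hand. This is a sketch under assumptions: the model dimensions below (hidden size, number of layers, key/value heads, head size) are taken to be those of the Qwen2-0.5B configuration and are not stated in this notebook, and `lora_params` is a hypothetical helper.\n",
        "\n",
        "```python\n",
        "# Assumed Qwen2-0.5B dimensions (from the model config, not from this notebook)\n",
        "hidden_size = 896\n",
        "num_layers = 24\n",
        "num_kv_heads = 2\n",
        "head_dim = 64\n",
        "rank = 8  # r in LoraConfig\n",
        "\n",
        "def lora_params(in_features, out_features, r):\n",
        "    # LoRA adds two matrices per adapted layer: A (r x in) and B (out x r)\n",
        "    return r * (in_features + out_features)\n",
        "\n",
        "q_proj = lora_params(hidden_size, hidden_size, rank)\n",
        "v_proj = lora_params(hidden_size, num_kv_heads * head_dim, rank)\n",
        "total = num_layers * (q_proj + v_proj)\n",
        "print(total)\n",
        "```\n",
        "\n",
        "Under these assumptions the total is 540,672, which matches the `trainable params` value printed above."
      ]
    },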
    {
      "cell_type": "markdown",
      "source": [
        "### 3.3 Loading Reward Functions\n",
        "\n",
        "For the reward component of the system, we can use either pretrained reward models or reward functions defined directly in code. For training, the DeepSeek-R1 authors used an accuracy-based reward model evaluates whether the response is correct, alongside a format-based reward that ensures the model places its reasoning process between `<think> </think>` tags. You can find more details [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py). We can simply define and implement these reward functions as generic Python functions.\n",
        "\n",
        "In this case, we will utilize these reward functions:\n",
        "\n",
        "1. **Format Enforcement:** Ensures that the generation follows a specific format using `<think> </think> <answer> </answer>` tags for reasoning.  "
      ],
      "metadata": {
        "id": "4M6prmhAEodm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "def format_reward(completions, **kwargs):\n",
        "    \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
        "    pattern = r\"^<think>.*?</think>\\s*<answer>.*?</answer>$\"\n",
        "    completion_contents = [completion[0][\"content\"] for completion in completions]\n",
        "    matches = [re.match(pattern, content) for content in completion_contents]\n",
        "    rewards_list = [1.0 if match else 0.0 for match in matches]\n",
        "    return [1.0 if match else 0.0 for match in matches]"
      ],
      "metadata": {
        "id": "BE7ZgN_sDPNg"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "2. **Solution Accuracy:** Verifies whether the solution to the problem is correct."
      ],
      "metadata": {
        "id": "nOQMzHHDoE18"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from math_verify import LatexExtractionConfig, parse, verify\n",
        "def accuracy_reward(completions, **kwargs):\n",
        "    \"\"\"Reward function that checks if the completion is the same as the ground truth.\"\"\"\n",
        "    solutions = kwargs['solution']\n",
        "    completion_contents = [completion[0][\"content\"] for completion in completions]\n",
        "    rewards = []\n",
        "    for content, solution in zip(completion_contents, solutions):\n",
        "        gold_parsed = parse(solution, extraction_mode=\"first_match\", extraction_config=[LatexExtractionConfig()])\n",
        "        answer_parsed = parse(content, extraction_mode=\"first_match\", extraction_config=[LatexExtractionConfig()])\n",
        "        if len(gold_parsed) != 0:\n",
        "            try:\n",
        "                rewards.append(float(verify(answer_parsed, gold_parsed)))\n",
        "            except Exception:\n",
        "                rewards.append(0.0)\n",
        "        else:\n",
        "            rewards.append(1.0)\n",
        "    return rewards"
      ],
      "metadata": {
        "id": "P3VIGZL4FLxA"
      },
      "execution_count": 10,
      "outputs": []
    },
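    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how the format pattern treats different completions, here is a small self-contained sketch. It assumes the same regular expression as `format_reward`; the helper `check_format` and the sample strings are hypothetical. It compares matching with and without `re.DOTALL`, since `.` does not match newlines by default.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "PATTERN = r'^<think>.*?</think>\\s*<answer>.*?</answer>$'\n",
        "\n",
        "def check_format(text, flags=0):\n",
        "    # Return 1.0 when the text matches the tag layout, else 0.0\n",
        "    return 1.0 if re.match(PATTERN, text, flags) else 0.0\n",
        "\n",
        "single_line = '<think>2 + 2 = 4</think><answer>4</answer>'\n",
        "multi_line = '<think>First add.\\nThen check.</think>\\n<answer>4</answer>'\n",
        "missing_answer = '<think>2 + 2 = 4</think>'\n",
        "\n",
        "# Completions arrive as a list of message lists, one per generation\n",
        "samples = (single_line, multi_line, missing_answer)\n",
        "completions = [[{'role': 'assistant', 'content': text}] for text in samples]\n",
        "contents = [completion[0]['content'] for completion in completions]\n",
        "\n",
        "default_scores = [check_format(text) for text in contents]\n",
        "dotall_scores = [check_format(text, re.DOTALL) for text in contents]\n",
        "print(default_scores)  # [1.0, 0.0, 0.0]\n",
        "print(dotall_scores)   # [1.0, 1.0, 0.0]\n",
        "```\n",
        "\n",
        "The multi-line completion only earns the format reward when `re.DOTALL` is set, which matters because reasoning traces usually span several lines."
      ]
    },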
    {
      "cell_type": "markdown",
      "source": [
        "### 3.4 Configuring GRPO Training Parameters\n",
        "\n",
        "Next, let's configure the training parameters for GRPO. We recommend experimenting with the `max_completion_length`, `num_generations`, and `max_prompt_length` parameters (refer to the image at the beginning for details about each of them).\n",
        "\n",
        "To keep things simple, we’ll start by training for just one epoch and reducing the `max_completion_length`, `num_generations`, and `max_prompt_length` from their default values."
      ],
      "metadata": {
        "id": "qW_3r8T1EtNg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "SbqX1pQUKaSM"
      },
      "outputs": [],
      "source": [
        "from trl import GRPOConfig\n",
        "\n",
        "# Configure training arguments using GRPOConfig\n",
        "training_args = GRPOConfig(\n",
        "    output_dir=\"Qwen2-0.5B-GRPO-test\",\n",
        "    learning_rate=1e-5,\n",
        "    remove_unused_columns=False, # to access the solution column in accuracy_reward\n",
        "    gradient_accumulation_steps=16,\n",
        "    num_train_epochs=1,\n",
        "    bf16=True,\n",
        "\n",
        "    # Parameters that control de data preprocessing\n",
        "    max_completion_length=64, # default: 256\n",
        "    num_generations=4, # default: 8\n",
        "    max_prompt_length=128, # default: 512\n",
        "\n",
        "    # Parameters related to reporting and saving\n",
        "    report_to=[\"tensorboard\"],\n",
        "    logging_steps=10,\n",
        "    push_to_hub=True,\n",
        "    save_strategy=\"steps\",\n",
        "    save_steps=10,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pOUrD9P-y-Kf"
      },
      "source": [
        "### 3.5 Training the Model 🏃\n",
        "\n",
        "Now, let's configure the trainer and start training the model!\n",
        "\n",
        "In this case, we pass the two reward functions we previously defined to the trainer\n",
        "\n",
        "Below, you'll find a diagram of the training procedure we'll be reproducing, which is sourced from the [Open-R1 project](https://github.com/huggingface/open-r1)."
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "![image.png]()"
      ],
      "metadata": {
        "id": "xxGhmAx-ZxWQ"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "k_jk-U7ULYtA"
      },
      "outputs": [],
      "source": [
        "from trl import GRPOTrainer\n",
        "\n",
        "trainer = GRPOTrainer(\n",
        "    model=model,\n",
        "    reward_funcs=[format_reward, accuracy_reward],\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NlDsh4WvWCx0"
      },
      "source": [
        "Time to train the model! 🎉"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.train()"
      ],
      "metadata": {
        "id": "IniPLesA13Qd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's save the results 💾"
      ],
      "metadata": {
        "id": "z7_y1x7E1JY9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.save_model(training_args.output_dir)\n",
        "trainer.push_to_hub(dataset_name=dataset_id)"
      ],
      "metadata": {
        "id": "Cazf4AB2nbRT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Below, you can review the Tensorboard results for the training. They look promising!"
      ],
      "metadata": {
        "id": "CqUFU6t71iNi"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "![image.png]()"
      ],
      "metadata": {
        "id": "1qfCDmaL1XvL"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 4. Check the Model Performance\n",
        "\n",
        "We've kept things simple so far, but now let's check if the model has already learned to reason. We'll load the saved model and run an evaluation on a test sample."
      ],
      "metadata": {
        "id": "MBv5BK1a1N-0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "model_id = \"sergiopaniego/Qwen2-0.5B-GRPO\"\n",
        "trained_model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    torch_dtype=\"auto\",\n",
        "    device_map=\"auto\",\n",
        ")\n",
        "trained_tokenizer = AutoTokenizer.from_pretrained(model_id)"
      ],
      "metadata": {
        "id": "sHOjAE1-PB1D"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's check one sample from the test set!"
      ],
      "metadata": {
        "id": "vuldkJXycdn7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(test_dataset['prompt'][0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uhMWeOMhIC8t",
        "outputId": "3dd9075f-2884-43c3-a16f-2a8c2c3575a3"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[{'content': 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>', 'role': 'system'}, {'content': \"In 1988, a person's age was equal to the sum of the digits of their birth year. How old was this person?\", 'role': 'user'}]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We'll create a function to interact with the model. In addition to generating the answer, we'll measure the inference duration and count the number of generated tokens. This will give us insights into how much the model has reasoned during generation."
      ],
      "metadata": {
        "id": "Y1b4udjecFUF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "\n",
        "def generate_with_reasoning(prompt):\n",
        "  # Build the prompt from the dataset\n",
        "  prompt = \" \".join(entry['content'] for entry in prompt)\n",
        "\n",
        "  # Tokenize and move to the same device as the model\n",
        "  inputs = trained_tokenizer(prompt, return_tensors=\"pt\").to(trained_model.device)\n",
        "\n",
        "  # Generate text without gradients\n",
        "  start_time = time.time()\n",
        "  with torch.no_grad():\n",
        "      output_ids = trained_model.generate(**inputs, max_length=500)\n",
        "  end_time = time.time()\n",
        "\n",
        "  # Decode and extract model response\n",
        "  generated_text = trained_tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
        "\n",
        "  # Get inference time\n",
        "  inference_duration = end_time - start_time\n",
        "\n",
        "  # Get number of generated tokens\n",
        "  num_input_tokens = inputs['input_ids'].shape[1]\n",
        "  num_generated_tokens = output_ids.shape[1] - num_input_tokens\n",
        "\n",
        "  return generated_text, inference_duration, num_generated_tokens"
      ],
      "metadata": {
        "id": "SsT_ujJZIUOf"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's generate the answer for that test sample!"
      ],
      "metadata": {
        "id": "3V7XuDRWcaNr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = test_dataset['prompt'][0]\n",
        "generated_text, inference_duration, num_generated_tokens = generate_with_reasoning(prompt)\n",
        "print(generated_text)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5maa5GwNMW68",
        "outputId": "a71a259d-693b-43e8-f194-62fe1d929ca2"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer> In 1988, a person's age was equal to the sum of the digits of their birth year. How old was this person?<think>\n",
            "The reasoning process is that if the sum of the digits of the birth year is equal to the person's age, then the person must have been born in a given year.\n",
            "\n",
            "<think>\n",
            "The answer is: 1988\n",
            "</think>\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "The model already demonstrates the ability to generate the correct `<think>` and `<answer>` tags, even though the solution itself is incorrect.\n",
        "\n",
        "Given the inference time and the number of generated tokens, this approach shows potential benefits:"
      ],
      "metadata": {
        "id": "-v-p8p8mckTr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Inference time: {inference_duration:.2f} seconds\")\n",
        "print(f\"Generated tokens: {num_generated_tokens}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zW4y6HmjKGQ2",
        "outputId": "ee0a853e-a1a0-4588-8dec-478602fd7094"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Inference time: 2.09 seconds\n",
            "Generated tokens: 55\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let’s strip the prompt from the generated text so that only the model’s response remains:"
      ],
      "metadata": {
        "id": "CkFAbH_7dFtY"
      }
    },
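    {
      "cell_type": "code",
      "source": [
        "# Rebuild the prompt text as it appears at the start of the decoded output\n",
        "# (message contents joined by a single space)\n",
        "prompt_text = \" \".join(entry['content'] for entry in prompt)\n",
        "# Slice off the prompt prefix, keeping only the newly generated text\n",
        "response_text = generated_text[len(prompt_text):].strip()\n",
        "print(response_text)"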
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gasQhrqCIzNq",
        "outputId": "3532ec76-b887-4e9d-aec9-3689f2bd3c0f"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<think>\n",
            "The reasoning process is that if the sum of the digits of the birth year is equal to the person's age, then the person must have been born in a given year.\n",
            "\n",
            "<think>\n",
            "The answer is: 1988\n",
            "</think>\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We observe that the model demonstrates some reasoning capabilities, although these are limited. This can be attributed to several factors: the use of a small model, a limited subset of the dataset, and a short training duration to keep the process simple and practical for a notebook environment.\n",
        "\n",
        "Additionally, the complexity of the dataset plays a role. Simplifying the problem might yield better results, as demonstrated [here](https://www.philschmid.de/mini-deepseek-r1).\n",
        "\n",
        "Despite these constraints, this technique shows great promise. The release of DeepSeek-R1 and the adoption of this training approach could lead to significant breakthroughs in the coming months!"
      ],
      "metadata": {
        "id": "l9jRqV8Uo9iX"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 5. Continuing Your Learning Journey 🧑‍🎓\n",
        "\n",
        "This notebook is only a first step in exploring the GRPO trainer and the DeepSeek-R1 model. If you’re eager to dive deeper, explore the resources linked throughout the notebook, as well as these additional materials:\n",
        "\n",
        "* [DeepSeek-R1's repo](https://github.com/deepseek-ai/DeepSeek-R1/)\n",
        "* [DeepSeek-R1's paper](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf)\n",
        "* [Open reproduction of DeepSeek-R1](https://github.com/huggingface/open-r1/)\n",
        "* [GRPO TRL trainer](https://huggingface.co/docs/trl/main/en/grpo_trainer)\n",
        "* [Phil Schmid’s DeepSeek-R1 Blog Post](https://www.philschmid.de/deepseek-r1)\n",
        "* [Phil Schmid’s mini DeepSeek-R1 Blog Post](https://www.philschmid.de/mini-deepseek-r1)\n",
        "* [Illustrated DeepSeek-R1](https://newsletter.languagemodels.co/p/the-illustrated-deepseek-r1)\n",
        "* [The LM Book’s DeepSeek-R1 Article](https://thelmbook.com/articles/#!./DeepSeek-R1.md)\n",
        "\n",
        "Happy learning and experimenting! 🚀"
      ],
      "metadata": {
        "id": "Uh4inHFUFIku"
      }
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}